library(redist)
library(tidyverse)
library(maptools)
library(doMC)
library(igraph)
library(parallel)
source("random_submap.R")

## Plot theme used for figures: theme_classic() with the grid and panel
## background removed, a black axis line and panel border, no strip
## background, a centred title, and an untitled legend placed at the bottom.
## Takes no arguments and returns a ggplot2 theme object.
ben_theme <- function(){
    blank <- element_blank()
    black_line <- element_line(colour = "black")
    black_frame <- element_rect(colour = "black", fill = NA, size = 1)
    centred_title <- element_text(hjust = 0.5)

    ## Overrides layered on top of the classic theme
    overrides <- theme(
        panel.background = blank,
        panel.grid.major = blank,
        panel.grid.minor = blank,
        axis.line = black_line,
        panel.border = black_frame,
        strip.background = blank,
        legend.position = "bottom",
        legend.title = blank,
        plot.title = centred_title
    )

    theme_classic() + overrides
}
## Inverse-probability reweighting of sampled partitions.
##
## Selects the draws in `x` whose `distance_parity` is at most `pop` and whose
## `beta_sequence` equals `beta`, then resamples those draws with replacement
## (same number of draws as were selected) using weights
## 1 / exp(beta * constraint_pop).
##
## Arguments:
##   x    - list with elements `partitions` (matrix, one column per draw) and
##          per-draw vectors `distance_parity`, `beta_sequence`,
##          `constraint_pop`
##   beta - value matched exactly against `x$beta_sequence` and used in the
##          weights
##   pop  - upper bound (inclusive) on `x$distance_parity`
##
## Returns: a matrix holding the resampled columns of `x$partitions`. The
## result is always a matrix, including when one or zero draws qualify.
ipw <- function(x, beta, pop){
    indpop <- which(x$distance_parity <= pop)
    indbeta <- which(x$beta_sequence == beta)
    inds <- intersect(indpop, indbeta)

    ## Nothing qualifies: return a zero-column matrix instead of letting the
    ## sampler fail with "too few positive probabilities".
    if(length(inds) == 0){
        return(x$partitions[, integer(0), drop = FALSE])
    }

    psi <- x$constraint_pop[inds]
    w <- 1 / exp(beta * psi)

    ## Sample positions within `inds` rather than calling sample(inds, ...):
    ## sample() treats a single number n as 1:n, which breaks when exactly
    ## one draw qualifies.
    picked <- sample.int(length(inds), length(inds), replace = TRUE, prob = w)

    ## drop = FALSE keeps the matrix shape when a single column is selected.
    x$partitions[, inds[picked], drop = FALSE]
}

## Load the base map from its shapefile
nh <- readShapePoly("../data/nh/nh_final.shp")

## Parameter grid: one row per simulation index, with the precinct count and
## district count held fixed
params <- expand.grid(nprecs = 25, ndists = 2, nsims = seq_len(200))

## Number of iterations requested from redist.mcmc for each chain
appx_sims <- 1e7

## Scratch directory for the intermediate edge-list and solution files;
## no warning is raised if it already exists
dir.create("../data/qq_test_scratch", showWarnings = FALSE)

## --------------
## Loop over sims
## --------------
nsims <- 25

## Leave two cores free when possible, but always register at least one
## worker (detectCores() - 2 is zero, negative or NA on small machines).
registerDoMC(max(1, detectCores() - 2, na.rm = TRUE))

## Each iteration draws a random sub-map, enumerates its partitions with the
## external enumpart tool to get the reference segregation distribution, runs
## redist.mcmc on the same sub-map, and returns the Kolmogorov-Smirnov p-value
## comparing the two (restricted to plans within 10% of population parity).
## ks_out is the vector of those p-values, one per iteration.
ks_out <- foreach(i = seq_len(nsims), .combine = "c") %dopar% {

    ## Get parameters
    nprecs <- params$nprecs[i]
    ndists <- params$ndists[i]

    ## Scratch files for this iteration, keyed by the iteration index
    scratch_file <- function(stem){
        paste0("../data/qq_test_scratch/", stem, "_", i, ".dat")
    }
    map_file <- scratch_file("adjlist_map")
    ordered_file <- scratch_file("adjlist_map_ordered")
    sols_file <- scratch_file("adjlist_map_sols")

    ## Sample map; add 1 to every population and vote count
    nh_sub <- random_submap(nh, nprecs)
    nh_sub@data$POP100 <- nh_sub@data$POP100 + 1
    nh_sub@data$PRES_RVOTE <- nh_sub@data$PRES_RVOTE + 1

    ## Convert shp file to adjacency list
    ## NOTE(review): poly2nb is not provided by any package attached in this
    ## file - presumably random_submap.R attaches it; verify.
    adjlist <- poly2nb(nh_sub, queen = FALSE)

    ## Flatten the adjacency list into a two-column edge list, keeping each
    ## undirected edge once (lower index first). Built in one vectorised pass
    ## instead of growing a matrix with rbind inside nested loops; entries
    ## that are not greater than their source index are dropped by the filter.
    edge_from <- rep(seq_along(adjlist), lengths(adjlist))
    edge_to <- unlist(adjlist)
    keep_edge <- edge_to > edge_from
    adjlist_map <- cbind(edge_from[keep_edge], edge_to[keep_edge])

    ## The output file is passed positionally: the argument is named `file`
    ## in current readr and `path` in older releases.
    write_delim(data.frame(adjlist_map), map_file, col_names = FALSE)

    ## Order edges
    system(paste0("python ndscut.py <", map_file, " >", ordered_file))

    ## Run enumpart
    system(paste0("../enumpart_private/enumpart ", ordered_file,
                  " -k ", ndists, " -allsols >", sols_file))

    ## Load solutions, get true distribution
    sols <- readLines(sols_file)
    el <- do.call(rbind, strsplit(readLines(ordered_file), split = " "))
    storage.mode(el) <- "numeric"

    ## Each solution line lists row indices into the ordered edge list; the
    ## connected components of those edges are the districts. Plain lapply is
    ## used because this already runs inside a %dopar% worker - forking
    ## detectCores() more processes per worker oversubscribes the machine.
    sols_out <- lapply(seq_along(sols), function(s){
        sol_split <- as.numeric(strsplit(sols[s], split = " ")[[1]])
        el_sub <- el[sol_split, , drop = FALSE]
        comps <- components(graph_from_edgelist(el_sub, directed = FALSE))
        membership <- comps$membership
        if(comps$no < ndists){
            ## Fewer components than districts: append one new label per
            ## missing district. Presumably these are vertices absent from
            ## the selected edges - TODO confirm against enumpart's output.
            membership <- c(membership, (max(membership) + 1):ndists)
        }
        membership
    })
    sols_out <- do.call(cbind, sols_out)

    target <- redist.segcalc(sols_out, grouppop = nh_sub@data$PRES_RVOTE,
                             fullpop = nh_sub@data$POP100)

    ## Maximum relative deviation from equal district population, per plan
    popdist <- apply(sols_out, 2, function(x){
        distpop <- tapply(nh_sub@data$POP100, x, sum)
        parpop <- sum(distpop) / length(distpop)
        max(abs(distpop / parpop - 1))
    })

    ## --------------------
    ## Run algorithm - MCMC
    ## --------------------

    ## Keep every `thin`-th draw of a redist.mcmc result. The first element
    ## is subset by column and every other element by position, so this
    ## assumes the partitions matrix is stored first.
    thin_chain <- function(x, thin = 1000){
        inds <- seq(1, ncol(x$partitions), by = thin)
        x_new <- vector(mode = "list", length = length(x))
        for(j in seq_along(x)){

            ## Subset the matrix first, then the vectors
            if(j == 1){
                x_new[[j]] <- x[[j]][, inds, drop = FALSE]
            }else{
                x_new[[j]] <- x[[j]][inds]
            }

        }
        names(x_new) <- names(x)
        class(x_new) <- "redist"
        return(x_new)
    }

    ## No parity
    ## betaweights is only needed by the tempering options commented out
    ## below. It is built without a loop so the foreach index `i` is not
    ## overwritten.
    betaweights <- 2^(1:10)
    mcmc_out <- redist.mcmc(nh_sub, nh_sub@data$POP100, nsims = appx_sims,
                            ndists = ndists, constraint = "population",
                            ## temper = "simulated",
                            ## betaweights = betaweights,
                            maxiterrsg = 500000,
                            beta = -10)
    mcmc_out <- ipw(thin_chain(mcmc_out), beta = -10, pop = .1)
    mcmc_seg_10p <- redist.segcalc(mcmc_out, grouppop = nh_sub@data$PRES_RVOTE,
                                   fullpop = nh_sub@data$POP100)

    ## Value of the iteration: KS p-value of enumerated vs. sampled plans
    ks.test(target[popdist <= .1], mcmc_seg_10p)$p.value
}

## QQ plot of the KS p-values against the uniform distribution, with a dashed
## y = x reference line. The slope and intercept are passed as fixed
## parameters rather than inside aes(): with aes(), the layer inherits the
## plot data and draws one identical line per row.
ggplot(data.frame(ks_pval = ks_out),
       aes(sample = ks_pval)) +
    stat_qq(distribution = stats::qunif) +
    geom_abline(intercept = 0, slope = 1, linetype = "dashed")
